{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "369c3444",
   "metadata": {},
   "source": [
    "# ReadTheDocs Retrieval-Augmented Generation (RAG) using Zilliz Free Tier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6ffd11a",
   "metadata": {},
   "source": [
    "In this notebook, we are going to use Milvus documentation pages to create a chatbot about our product.  The chatbot follows the RAG steps: it retrieves chunks of data using semantic vector search, then feeds the Question + Context as a prompt to an LLM to generate an answer.\n",
    "\n",
    "Many RAG demos use OpenAI for the embedding model and ChatGPT for the generative AI model.  **In this notebook, we will demo an open-source stack for the embedding and retrieval steps.**\n",
    "\n",
    "Using open-source models for Q&A with retrieval saves money, since the calls we make most often (retrieval, evaluation, and development iterations) run for free against our own data.  We make only one paid call to OpenAI, for the final chat generation step. \n",
    "\n",
    "<div>\n",
    "<img src=\"../../images/rag_image.png\" width=\"80%\"/>\n",
    "</div>\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d7570b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For Colab, install these libraries first (pip takes space-separated names):\n",
    "# !pip install pymilvus langchain torch transformers sentence-transformers beautifulsoup4 pandas python-dotenv\n",
    "\n",
    "# Import common libraries.\n",
    "import sys, os, time, pprint\n",
    "import numpy as np\n",
    "\n",
    "# Import custom functions for splitting and search.\n",
    "sys.path.append(\"..\")  # Adds higher directory to python modules path.\n",
    "import milvus_utilities as _utils"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e059b674",
   "metadata": {},
   "source": [
    "## Download Milvus documentation to a local directory.\n",
    "\n",
    "The data we’ll use is our own product documentation web pages.  ReadTheDocs is an open-source, free documentation hosting platform, where documentation is typically written with the Sphinx document generator.\n",
    "\n",
    "The code block below downloads the web pages into a local directory called `rtdocs`.  \n",
    "\n",
    "I've already uploaded the `rtdocs` data folder to GitHub, so you should already have it if you cloned my repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "20dcdaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Uncomment to download readthedocs pages locally.\n",
    "\n",
    "# DOCS_PAGE=\"https://pymilvus.readthedocs.io/en/latest/\"\n",
    "# !echo $DOCS_PAGE\n",
    "\n",
    "# # Specify encoding to handle non-unicode characters in documentation.\n",
    "# !wget -r -A.html -P rtdocs --header=\"Accept-Charset: UTF-8\" $DOCS_PAGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb844837",
   "metadata": {},
   "source": [
    "## Start up a Zilliz free tier cluster.\n",
    "\n",
    "Code in this notebook uses fully-managed Milvus on [Zilliz Cloud free trial](https://cloud.zilliz.com/login).  \n",
    "  1. Choose the default \"Starter\" option when you provision > Create collection > Give it a name > Create cluster and collection.  \n",
    "  2. On the Cluster main page, copy your `API Key` and store it locally in a `.env` file.  See the note below for how to do that.\n",
    "  3. Also on the Cluster main page, copy the `Public Endpoint URI`.\n",
    "\n",
    "💡 Note: To keep your tokens private, best practice is to use an **env variable**.  See [how to save api key in env variable](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety). <br>\n",
    "\n",
    "In Jupyter, you also need a `.env` file (in the same directory as the notebooks) containing lines like this:\n",
    "- VARIABLE_NAME=value, for example `ZILLIZ_API_KEY=<your API key>`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0806d2db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of server: zilliz_cloud\n"
     ]
    }
   ],
   "source": [
    "# STEP 1. CONNECT TO MILVUS\n",
    "\n",
    "# !pip install pymilvus #python sdk for milvus\n",
    "from pymilvus import connections, utility\n",
    "\n",
    "# Jupyter notebooks:\n",
    "# from dotenv import load_dotenv\n",
    "# load_dotenv()\n",
    "# TOKEN = os.getenv(\"ZILLIZ_API_KEY\")\n",
    "\n",
    "# Usual way:\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "TOKEN = os.environ[\"ZILLIZ_API_KEY\"]\n",
    "\n",
    "# Connect to Zilliz cloud using endpoint URI and API key TOKEN.\n",
    "# Replace with the Public Endpoint URI of your own cluster.\n",
    "CLUSTER_ENDPOINT=\"https://in03-xxxx.api.gcp-us-west1.zillizcloud.com:443\"\n",
    "connections.connect(\n",
    "  alias='default',\n",
    "  #  Public endpoint obtained from Zilliz Cloud\n",
    "  uri=CLUSTER_ENDPOINT,\n",
    "  # API key or a colon-separated cluster username and password\n",
    "  token=TOKEN,\n",
    ")\n",
    "\n",
    "# Check that the server is ready by printing its version.\n",
    "print(f\"Type of server: {utility.get_server_version()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b01d6622",
   "metadata": {},
   "source": [
    "## Load the Embedding Model checkpoint and use it to create vector embeddings\n",
    "**Embedding model:**  We will use the open-source [sentence transformers](https://www.sbert.net/docs/pretrained_models.html) available on HuggingFace to encode the documentation text.  We will download the model from HuggingFace and run it locally. \n",
    "\n",
    "Two model parameters of note below:\n",
    "1. EMBEDDING_LENGTH refers to the dimensionality, or length, of the embedding vector. In this case, the model outputs ONE vector of length 1024 for each input text: the per-token embeddings (also 1024-dimensional) are averaged by MEAN pooling into a single vector. 1024 is the hidden size of BERT-large-style models such as this one; BERT-base-style models use 768. <br><br>\n",
    "2. MAX_SEQ_LENGTH is the maximum number of tokens the encoder model can handle per input sequence. In this case, if a sequence longer than 512 tokens is given to the model, everything past the 512th token is (silently!) chopped off.  This is why a chunking strategy is needed: input texts must be segmented into chunks short enough to fit in the model's input."
   ]
  },
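  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NumPy sketch below illustrates how these two parameters shape the model output. It rests on simplifying assumptions: random stand-in token embeddings instead of real model outputs, truncation applied to token vectors rather than during tokenization, and a plain average without attention masking. Tokens past `MAX_SEQ_LENGTH` are dropped, the remaining per-token vectors are averaged (MEAN pooling, as in the model printout in the next cell), and the result is L2-normalized. `mean_pool_and_normalize` is a hypothetical helper, not a sentence-transformers function.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "EMBEDDING_LENGTH = 1024  # embedding size reported for WhereIsAI/UAE-Large-V1\n",
    "MAX_SEQ_LENGTH = 512     # max input tokens reported for the same model\n",
    "\n",
    "\n",
    "def mean_pool_and_normalize(token_embeddings, max_seq_length=MAX_SEQ_LENGTH):\n",
    "    # Hypothetical helper. Simulate silent truncation: rows past max_seq_length are ignored.\n",
    "    kept = token_embeddings[:max_seq_length]\n",
    "    # MEAN pooling: average the per-token vectors into one vector per text.\n",
    "    pooled = kept.mean(axis=0)\n",
    "    # L2-normalize so the vector has unit length.\n",
    "    return pooled / np.linalg.norm(pooled)\n",
    "\n",
    "\n",
    "# Random stand-in for the token embeddings of a 600-token input (assumption).\n",
    "rng = np.random.default_rng(0)\n",
    "fake_tokens = rng.normal(size=(600, EMBEDDING_LENGTH)).astype(np.float32)\n",
    "\n",
    "vector = mean_pool_and_normalize(fake_tokens)\n",
    "print(vector.shape)  # (1024,)\n",
    "print(round(float(np.linalg.norm(vector)), 4))  # 1.0\n",
    "\n",
    "# Tokens 513 to 600 had no effect on the result.\n",
    "same = mean_pool_and_normalize(fake_tokens[:MAX_SEQ_LENGTH])\n",
    "print(np.allclose(vector, same))  # True\n",
    "```\n",
    "\n",
    "The last check shows why truncation is silent: the extra tokens simply never reach the pooled vector, so no error is raised."
   ]
  },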
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd2be7fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No sentence-transformers model found with name /Users/christybergman/.cache/torch/sentence_transformers/WhereIsAI_UAE-Large-V1. Creating a new one with MEAN pooling.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'sentence_transformers.SentenceTransformer.SentenceTransformer'>\n",
      "SentenceTransformer(\n",
      "  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
      "  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
      ")\n",
      "model_name: WhereIsAI/UAE-Large-V1\n",
      "EMBEDDING_LENGTH: 1024\n",
      "MAX_SEQ_LENGTH: 512\n"
     ]
    }
   ],
   "source": [
    "# STEP 2. DOWNLOAD AN OPEN SOURCE EMBEDDING MODEL.\n",
    "\n",
    "# Import torch.\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Initialize torch settings\n",
    "torch.backends.cudnn.deterministic = True\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"device: {DEVICE}\")\n",
    "\n",
    "# Load the model from huggingface model hub.\n",
    "# python -m pip install -U angle-emb\n",
    "model_name = \"WhereIsAI/UAE-Large-V1\"\n",
    "encoder = SentenceTransformer(model_name, device=DEVICE)\n",
    "print(type(encoder))\n",
    "print(encoder)\n",
    "\n",
    "# Get the model parameters and save for later.\n",
    "EMBEDDING_LENGTH = encoder.get_sentence_embedding_dimension()\n",
    "MAX_SEQ_LENGTH_IN_TOKENS = encoder.get_max_seq_length() \n",
    "# # Assume tokens are 3 characters long.\n",
    "# MAX_SEQ_LENGTH = MAX_SEQ_LENGTH_IN_TOKENS * 3\n",
    "# HF_EOS_TOKEN_LENGTH = 1 * 3\n",
    "# Instead, use 512 directly (chunks are later measured in characters, so this is conservative).\n",
    "MAX_SEQ_LENGTH = MAX_SEQ_LENGTH_IN_TOKENS\n",
    "HF_EOS_TOKEN_LENGTH = 1\n",
    "\n",
    "# Inspect model parameters.\n",
    "print(f\"model_name: {model_name}\")\n",
    "print(f\"EMBEDDING_LENGTH: {EMBEDDING_LENGTH}\")\n",
    "print(f\"MAX_SEQ_LENGTH: {MAX_SEQ_LENGTH}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Milvus collection\n",
    "\n",
    "You can think of a collection in Milvus like a \"table\" in SQL databases.  The **collection** will contain:\n",
    "- **Schema** (or [no-schema Milvus client](https://milvus.io/docs/using_milvusclient.md)).  \n",
    "💡 You'll need the vector `EMBEDDING_LENGTH` parameter from your embedding model.\n",
    "Typical values are:\n",
    "   - 768 for many sentence-transformers (SBERT) base models\n",
    "   - 1024 for the UAE-Large-V1 model used in this notebook\n",
    "   - 1536 for OpenAI's text-embedding-ada-002 model\n",
    "- **Vector index** for efficient vector search\n",
    "- **Vector distance metric** for measuring nearest neighbor vectors\n",
    "- **Consistency level**: In Milvus, strong consistency is possible; however, according to the [CAP theorem](https://en.wikipedia.org/wiki/CAP_theorem), stronger consistency comes at a cost, which in practice shows up as extra search latency. 💡 Searching documentation is not mission-critical, so [`eventually`](https://milvus.io/docs/consistency.md) consistent is fine here.\n",
    "\n",
    "## Add a Vector Index\n",
    "\n",
    "The vector index determines the vector **search algorithm** used to find the closest vectors in your data to the query a user submits.  \n",
    "\n",
    "Most vector indexes use different sets of parameters depending on whether the database is:\n",
    "- **inserting vectors** (creation mode), or\n",
    "- **searching vectors** (search mode) \n",
    "\n",
    "Scroll down the [docs page](https://milvus.io/docs/index.md) to see a table listing different vector indexes available on Milvus.  For example:\n",
    "- FLAT - deterministic exhaustive search\n",
    "- IVF_FLAT or IVF_SQ8 - Inverted-file (cluster-based) index (approximate search)\n",
    "- HNSW - Graph index (approximate search)\n",
    "- AUTOINDEX - Automatically determined based on OSS vs [Zilliz cloud](https://docs.zilliz.com/docs/autoindex-explained), type of GPU, size of data.\n",
    "\n",
    "Besides a search algorithm, we also need to specify a **distance metric**, that is, a definition of what is considered \"close\" in vector space.  In the cell below, the [`HNSW`](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) search index is chosen.  Its supported distance metrics are:\n",
    "- L2 - Euclidean (L2) distance\n",
    "- IP - Inner (dot) product\n",
    "- COSINE - Cosine similarity\n",
    "\n",
    "💡 Most use cases work better with normalized embeddings. For unit-length vectors, IP and COSINE give identical scores, and L2 ranks neighbors in the same order (squared L2 distance = 2 - 2 × cosine), so the metric choice matters mainly if you plan to keep your embeddings unnormalized."
   ]
  },
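  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NumPy sketch below checks the note above numerically. It uses random unit-length vectors as stand-ins for real embeddings, and the sizes are assumptions. For unit vectors, IP equals COSINE and the squared L2 distance equals `2 - 2 * cosine`, so all three metrics rank neighbors in the same order.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "\n",
    "# Random stand-ins for normalized embeddings (assumption: 100 docs, dim 1024).\n",
    "docs = rng.normal(size=(100, 1024))\n",
    "docs /= np.linalg.norm(docs, axis=1, keepdims=True)\n",
    "query_vec = rng.normal(size=1024)\n",
    "query_vec /= np.linalg.norm(query_vec)\n",
    "\n",
    "ip = docs @ query_vec  # IP: inner (dot) product\n",
    "cosine = ip / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query_vec))  # COSINE\n",
    "l2 = np.linalg.norm(docs - query_vec, axis=1)  # L2: Euclidean distance\n",
    "\n",
    "# For unit vectors, IP and COSINE are the same number.\n",
    "print(np.allclose(ip, cosine))  # True\n",
    "# Squared L2 distance is 2 - 2 * cosine, so a smaller L2 means a larger cosine.\n",
    "print(np.allclose(l2 ** 2, 2 - 2 * cosine))  # True\n",
    "# Therefore the top-5 neighbors are identical under all three metrics.\n",
    "top5_ip = np.argsort(-ip)[:5]\n",
    "top5_l2 = np.argsort(l2)[:5]\n",
    "print(np.array_equal(top5_ip, top5_l2))  # True\n",
    "```\n",
    "\n",
    "The scores point in opposite directions (a higher IP or COSINE score is closer, a lower L2 distance is closer), but the retrieved neighbors are the same."
   ]
  },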
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully created collection: `MilvusDocs`\n",
      "{'collection_name': 'MilvusDocs', 'auto_id': True, 'num_shards': 1, 'description': '', 'fields': [{'field_id': 100, 'name': 'id', 'description': '', 'type': 5, 'params': {}, 'element_type': 0, 'auto_id': True, 'is_primary': True}, {'field_id': 101, 'name': 'vector', 'description': '', 'type': 101, 'params': {'dim': 1024}, 'element_type': 0}], 'aliases': [], 'collection_id': 446268198625172175, 'consistency_level': 3, 'properties': {}, 'num_partitions': 1, 'enable_dynamic_field': True}\n"
     ]
    }
   ],
   "source": [
    "# STEP 3. CREATE A NO-SCHEMA MILVUS COLLECTION AND DEFINE THE DATABASE INDEX.\n",
    "\n",
    "from pymilvus import MilvusClient\n",
    "\n",
    "# Set the Milvus collection name.\n",
    "COLLECTION_NAME = \"MilvusDocs\"\n",
    "\n",
    "# Add custom HNSW search index to the collection.\n",
    "# M = max number of graph connections per node per layer. Larger M = denser graph.\n",
    "# Choice of M: 4~64; use larger M for larger data and longer embeddings.\n",
    "M = 16\n",
    "# efConstruction = number of candidate nearest neighbors considered per layer\n",
    "# while building the graph. Rule of thumb: an int in 8~512, e.g. efConstruction = M * 2.\n",
    "efConstruction = M * 2\n",
    "# Define the HNSW index parameters for the Zilliz cluster.\n",
    "INDEX_PARAMS = {\n",
    "    'M': M,\n",
    "    'efConstruction': efConstruction,\n",
    "}\n",
    "index_params = {\n",
    "    'index_type': 'HNSW',\n",
    "    'metric_type': 'COSINE',\n",
    "    'params': INDEX_PARAMS,\n",
    "}\n",
    "\n",
    "# The no-schema MilvusClient uses a flexible JSON key:value format.\n",
    "# https://milvus.io/docs/using_milvusclient.md\n",
    "mc = MilvusClient(\n",
    "    uri=CLUSTER_ENDPOINT,\n",
    "    # API key or a colon-separated cluster username and password\n",
    "    token=TOKEN)\n",
    "\n",
    "# Check if collection already exists, if so drop it.\n",
    "has = utility.has_collection(COLLECTION_NAME)\n",
    "if has:\n",
    "    drop_result = utility.drop_collection(COLLECTION_NAME)\n",
    "    print(f\"Successfully dropped collection: `{COLLECTION_NAME}`\")\n",
    "\n",
    "# Create the collection.\n",
    "mc.create_collection(COLLECTION_NAME, \n",
    "                     EMBEDDING_LENGTH,\n",
    "                     consistency_level=\"Eventually\", \n",
    "                     auto_id=True,  \n",
    "                     overwrite=True,\n",
    "                     # skip setting params below, if using AUTOINDEX\n",
    "                     params=index_params\n",
    "                    )\n",
    "\n",
    "print(f\"Successfully created collection: `{COLLECTION_NAME}`\")\n",
    "print(mc.describe_collection(COLLECTION_NAME))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c60423a5",
   "metadata": {},
   "source": [
    "## Chunking\n",
    "\n",
    "Before embedding, you need to decide your chunk strategy, chunk size, and chunk overlap.  In this demo, I will use:\n",
    "- **Strategy** = Use HTML header hierarchies (h1, h2).  Keep header sections together unless they are too long.\n",
    "- **Chunk size** = Use the embedding model's parameter `MAX_SEQ_LENGTH`, applied as a character count (the splitter measures length with `len`). For typical English text a token spans several characters, so this keeps chunks well under the 512-token limit.\n",
    "- **Overlap** = Rule-of-thumb 10-15%\n",
    "- **Function** = \n",
    "  - LangChain's `HTMLHeaderTextSplitter` to split HTML sections by header.\n",
    "  - LangChain's `RecursiveCharacterTextSplitter` to split up long sections recursively.\n",
    "\n",
    "\n",
    "Notice below that each chunk is grounded with its source document page.  <br>\n",
    "In addition, header titles are kept together with the chunk text."
   ]
  },
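  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how `chunk_size` and `chunk_overlap` interact, here is a simplified sketch. `fixed_size_chunks` is a hypothetical helper, not a LangChain API: it cuts fixed-size character windows that share `chunk_overlap` characters. `RecursiveCharacterTextSplitter` additionally tries to break on separators such as paragraphs and spaces, so its chunks are usually shorter than the limit.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper (not a LangChain API): fixed-size character windows where\n",
    "# consecutive windows share chunk_overlap characters.\n",
    "def fixed_size_chunks(text, chunk_size, chunk_overlap):\n",
    "    if not 0 <= chunk_overlap < chunk_size:\n",
    "        raise ValueError('chunk_overlap must be in [0, chunk_size)')\n",
    "    step = chunk_size - chunk_overlap\n",
    "    pieces = []\n",
    "    for start in range(0, len(text), step):\n",
    "        pieces.append(text[start:start + chunk_size])\n",
    "        if start + chunk_size >= len(text):\n",
    "            break  # this window already reaches the end of the text\n",
    "    return pieces\n",
    "\n",
    "\n",
    "# Same numbers as the chunking cell below: 511-character chunks, 10% (51-char) overlap.\n",
    "chunk_size = 511\n",
    "chunk_overlap = round(chunk_size * 0.10)\n",
    "sample_text = ''.join(chr(ord('a') + i % 26) for i in range(1200))\n",
    "\n",
    "pieces = fixed_size_chunks(sample_text, chunk_size, chunk_overlap)\n",
    "print([len(p) for p in pieces])  # [511, 511, 280]\n",
    "print(pieces[0][-chunk_overlap:] == pieces[1][:chunk_overlap])  # True\n",
    "```\n",
    "\n",
    "The overlap means a sentence cut at a chunk boundary still appears whole in at least one neighbor, at the cost of storing about 10% duplicate text."
   ]
  },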
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "30ef209a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded 8 documents\n"
     ]
    }
   ],
   "source": [
    "# STEP 4. PREPARE DATA: CHUNK AND EMBED\n",
    "\n",
    "## Read docs into LangChain using v 0.0.322\n",
    "#!pip install langchain beautifulsoup4\n",
    "from langchain.document_loaders import ReadTheDocsLoader\n",
    "\n",
    "loader = ReadTheDocsLoader(\"../RAG/rtdocs/pymilvus.readthedocs.io/en/latest/\"\n",
    "                           , encoding=\"utf-8\"\n",
    "                           , features=\"html.parser\")\n",
    "docs = loader.load()\n",
    "\n",
    "num_documents = len(docs)\n",
    "print(f\"loaded {num_documents} documents\")\n",
    "\n",
    "# Langchain v 0.0.354\n",
    "# from langchain_community.document_loaders.readthedocs import ReadTheDocsLoader\n",
    "\n",
    "# # Create an instance of ReadTheDocsLoader\n",
    "# loader = ReadTheDocsLoader(\"../RAG/rtdocs/pymilvus.readthedocs.io/en/latest/\", \n",
    "#                            encoding=\"utf-8\")\n",
    "\n",
    "# # Load the documents\n",
    "# docs = loader.load()\n",
    "\n",
    "# num_documents = len(docs)\n",
    "# print(f\"loaded {num_documents} documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chunk_size: 511, chunk_overlap: 51.0\n",
      "chunking time: 0.014161109924316406\n",
      "docs: 8, split into: 8\n",
      "split into chunks: 156, type: list of <class 'langchain.schema.document.Document'>\n",
      "\n",
      "Looking at a sample chunk...\n",
      "Installation¶ Installing via pip¶ PyMilvus is in the Python Package Index. PyMilvus only support pyt\n",
      "{'h1': 'Installation', 'h2': 'Installing via pip', 'source': '../RAG/rtdocs/pymilvus.readthedocs.io/en/latest/install.html'}\n"
     ]
    }
   ],
   "source": [
    "from langchain.text_splitter import HTMLHeaderTextSplitter, RecursiveCharacterTextSplitter\n",
    "from bs4 import BeautifulSoup\n",
    "\n",
    "# Define the headers to split on for the HTMLHeaderTextSplitter\n",
    "headers_to_split_on = [\n",
    "    (\"h1\", \"Header 1\"),\n",
    "    (\"h2\", \"Header 2\"),\n",
    "]\n",
    "# Create an instance of the HTMLHeaderTextSplitter\n",
    "html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)\n",
    "\n",
    "# Use the embedding model parameters.\n",
    "chunk_size = MAX_SEQ_LENGTH - HF_EOS_TOKEN_LENGTH\n",
    "chunk_overlap = np.round(chunk_size * 0.10, 0)\n",
    "print(f\"chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}\")\n",
    "\n",
    "# Create an instance of the RecursiveCharacterTextSplitter\n",
    "child_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size = chunk_size,\n",
    "    chunk_overlap = chunk_overlap,\n",
    "    length_function = len,\n",
    ")\n",
    "\n",
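    "# Split the HTML text using the HTMLHeaderTextSplitter.\n",
    "# Note: ReadTheDocsLoader returns extracted text rather than raw HTML, so this step\n",
    "# may leave each doc whole (the output shows 8 docs -> 8 splits); the h1/h2 metadata\n",
    "# then comes from the first Sphinx '¶' header markers in each doc's text.\n",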
    "start_time = time.time()\n",
    "html_header_splits = []\n",
    "for doc in docs:\n",
    "    soup = BeautifulSoup(doc.page_content, 'html.parser')\n",
    "    splits = html_splitter.split_text(str(soup))\n",
    "    for split in splits:\n",
    "        # Add the source URL and header values to the metadata\n",
    "        metadata = {}\n",
    "        new_text = split.page_content\n",
    "        for header_name, metadata_header_name in headers_to_split_on:\n",
    "            header_value = new_text.split(\"¶ \")[0].strip()\n",
    "            metadata[header_name] = header_value\n",
    "            try:\n",
    "                new_text = new_text.split(\"¶ \")[1].strip()\n",
    "            except IndexError:\n",
    "                break\n",
    "        split.metadata = {\n",
    "            **metadata,\n",
    "            \"source\": doc.metadata[\"source\"]\n",
    "        }\n",
    "        # The header text stays in page_content, since the text already begins with it.\n",
    "    html_header_splits.extend(splits)\n",
    "\n",
    "# Split the documents further into smaller, recursive chunks.\n",
    "chunks = child_splitter.split_documents(html_header_splits)\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"chunking time: {end_time - start_time}\")\n",
    "print(f\"docs: {len(docs)}, split into: {len(html_header_splits)}\")\n",
    "print(f\"split into chunks: {len(chunks)}, type: list of {type(chunks[0])}\") \n",
    "\n",
    "# Inspect a chunk.\n",
    "print()\n",
    "print(\"Looking at a sample chunk...\")\n",
    "print(chunks[0].page_content[:100])\n",
    "print(chunks[0].metadata)\n",
    "\n",
    "# # TODO - Uncomment to print child splits with their associated header metadata.\n",
    "# print()\n",
    "# for child in chunks:\n",
    "#     print(f\"Content: {child.page_content}\")\n",
    "#     print(f\"Metadata: {child.metadata}\")\n",
    "#     print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "512130a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installation¶ Installing via pip¶ PyMilvus is in the Python Package Index. PyMilvus only support pyt\n",
      "{'h1': 'Installation', 'h2': 'Installing via pip', 'source': 'https://pymilvus.readthedocs.io/en/latest/install.html'}\n"
     ]
    }
   ],
   "source": [
    "# Clean up the metadata URLs: '../RAG/rtdocs/pymilvus...' becomes 'https://pymilvus...'.\n",
    "for doc in chunks:\n",
    "    new_url = doc.metadata[\"source\"]\n",
    "    new_url = new_url.replace(\"../RAG/rtdocs\", \"https:/\")\n",
    "    doc.metadata.update({\"source\": new_url})\n",
    "\n",
    "print(chunks[0].page_content[:100])\n",
    "print(chunks[0].metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9bd8153",
   "metadata": {},
   "source": [
    "## Insert data into Milvus\n",
    "\n",
    "For each original text chunk, we'll write five fields (`vector, chunk, source, h1, h2`) into the database.\n",
    "\n",
    "<div>\n",
    "<img src=\"../../images/db_insert.png\" width=\"80%\"/>\n",
    "</div>\n",
    "\n",
    "**The Milvus Client wrapper can only handle loading data from a list of dictionaries.**\n",
    "\n",
    "Otherwise, in general, Milvus supports loading data from:\n",
    "- pandas dataframes \n",
    "- list of dictionaries\n",
    "\n",
    "Below, we encode each chunk with the HuggingFace embedding model loaded earlier, running locally as the encoder.  "
   ]
  },
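  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your chunks start out in a pandas DataFrame, one way to produce the list of dictionaries that the Milvus Client wrapper expects is pandas' `DataFrame.to_dict('records')`, which returns one dict per row keyed by column name. The rows below are made-up placeholders (3-dimensional vectors instead of real 1024-dimensional embeddings, and example.com URLs); the keys mirror the fields inserted in the next cell.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up placeholder rows; real vectors would hold EMBEDDING_LENGTH (1024) floats.\n",
    "df = pd.DataFrame({\n",
    "    'vector': [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],\n",
    "    'chunk': ['PyMilvus is in the Python Package Index.', 'HNSW is a graph index.'],\n",
    "    'source': ['https://example.com/install.html', 'https://example.com/index.html'],\n",
    "    'h1': ['Installation', 'Index'],\n",
    "    'h2': ['Installing via pip', ''],\n",
    "})\n",
    "\n",
    "# One dict per row, keyed by column name: the list-of-dicts shape used for insert.\n",
    "rows = df.to_dict('records')\n",
    "print(len(rows))  # 2\n",
    "print(sorted(rows[0]))  # ['chunk', 'h1', 'h2', 'source', 'vector']\n",
    "print(rows[1]['h1'])  # Index\n",
    "```\n",
    "\n",
    "With real data, `rows` could be passed as `data=` to `mc.insert`, the same way `chunk_list` is passed in the next cell."
   ]
  },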
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start inserting entities\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Milvus Client insert time for 156 vectors: 1.6547510623931885 seconds\n"
     ]
    }
   ],
   "source": [
    "# STEP 5. INSERT CHUNKS AND EMBEDDINGS IN ZILLIZ.\n",
    "\n",
    "# Generate embeddings for all chunks in one batch using the local HuggingFace encoder.\n",
    "# normalize_embeddings=True L2-normalizes each vector (same as F.normalize with p=2).\n",
    "texts = [chunk.page_content for chunk in chunks]\n",
    "embeddings = encoder.encode(texts, normalize_embeddings=True)\n",
    "\n",
    "# Convert chunks to a list of dictionaries.\n",
    "chunk_list = []\n",
    "for chunk, vector in zip(chunks, embeddings):\n",
    "    # Only use h1, h2. Truncate the metadata in case it is too long.\n",
    "    h1 = chunk.metadata.get('h1', '')[:50]\n",
    "    h2 = chunk.metadata.get('h2', '')[:50]\n",
    "    # Assemble embedding vector, original text chunk, metadata.\n",
    "    chunk_dict = {\n",
    "        'vector': vector.astype(np.float32),\n",
    "        'chunk': chunk.page_content,\n",
    "        'source': chunk.metadata['source'],\n",
    "        'h1': h1,\n",
    "        'h2': h2,\n",
    "    }\n",
    "    chunk_list.append(chunk_dict)\n",
    "\n",
    "# Insert data into the Milvus collection.\n",
    "print(\"Start inserting entities\")\n",
    "start_time = time.time()\n",
    "insert_result = mc.insert(\n",
    "    COLLECTION_NAME,\n",
    "    data=chunk_list,\n",
    "    progress_bar=True)\n",
    "end_time = time.time()\n",
    "print(f\"Milvus Client insert time for {len(chunk_list)} vectors: {end_time - start_time} seconds\")\n",
    "\n",
    "# After the final entity is inserted, call flush to seal the growing segments left in memory.\n",
    "mc.flush(COLLECTION_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02c589ff",
   "metadata": {},
   "source": [
    "## Ask a question about your data\n",
    "\n",
    "So far in this demo notebook: \n",
    "1. Your custom data has been mapped into a vector embedding space\n",
    "2. Those vector embeddings have been saved into a vector database\n",
    "\n",
    "Next, you can ask a question about your custom data!\n",
    "\n",
    "💡 In LLM vocabulary:\n",
    "> **Query** is the generic term for user questions.  \n",
    "A query can be a batch of many individual questions; in the cells below, `query` holds all the questions from the ground-truth file.\n",
    "\n",
    "> **Question** usually refers to a single user question.  \n",
    "In our example below, the user question is \"What do the parameters for HNSW mean?\"\n",
    "\n",
    "> **Semantic Search** = very fast search of the entire knowledge base to find the `TOP_K` documentation chunks with the closest embeddings to the user's query.\n",
    "\n",
    "💡 Always use the same embedding model for the stored data and the query: vectors from different models live in different embedding spaces and cannot be compared."
   ]
  },
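  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, semantic search scores every stored chunk against the embedded question and keeps the `TOP_K` most similar. The NumPy sketch below does this exhaustively, like a FLAT index, with COSINE similarity. It uses random stand-in vectors instead of real embeddings, and `top_k_cosine` is a hypothetical helper. An HNSW index in Milvus returns an approximation of the same ranking without scoring every vector.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def top_k_cosine(query_vec, doc_vecs, top_k):\n",
    "    # Exhaustive (FLAT-style) search: score every row, keep the top_k most similar.\n",
    "    q = query_vec / np.linalg.norm(query_vec)\n",
    "    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)\n",
    "    scores = d @ q  # cosine similarity for every chunk\n",
    "    order = np.argsort(-scores)[:top_k]  # highest similarity first\n",
    "    return order, scores[order]\n",
    "\n",
    "\n",
    "# Random stand-ins for 156 chunk embeddings of length 1024 (assumption).\n",
    "rng = np.random.default_rng(7)\n",
    "chunk_vecs = rng.normal(size=(156, 1024))\n",
    "# Make chunk 42 a slightly perturbed copy of the question vector so it should rank first.\n",
    "question_vec = rng.normal(size=1024)\n",
    "chunk_vecs[42] = question_vec + 0.01 * rng.normal(size=1024)\n",
    "\n",
    "TOP_K = 3\n",
    "ids, sims = top_k_cosine(question_vec, chunk_vecs, TOP_K)\n",
    "print(ids[0])  # 42\n",
    "print(len(ids))  # 3\n",
    "```"
   ]
  },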
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5e7f41f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Question</th>\n",
       "      <th>ground_truth_answer</th>\n",
       "      <th>Uri</th>\n",
       "      <th>retrieval_chunk_text</th>\n",
       "      <th>H1</th>\n",
       "      <th>H2</th>\n",
       "      <th>assistant_answer</th>\n",
       "      <th>Score</th>\n",
       "      <th>Reason</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What do the parameters for HNSW mean?\\n</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>https://pymilvus.readthedocs.io/en/latest/para...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>Index</td>\n",
       "      <td>Milvus support to create index to accelerate v...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are HNSW good default parameters when dat...</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>https://pymilvus.readthedocs.io/en/latest/para...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>what is the default distance metric used in AU...</td>\n",
       "      <td>Trick answer:  IP inner product, not yet updat...</td>\n",
       "      <td>https://pymilvus.readthedocs.io/en/latest/tuto...</td>\n",
       "      <td>The attributes of collection can be extracted ...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>How did New York City get its name?</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>https://en.wikipedia.org/wiki/New_York_City</td>\n",
       "      <td>Etymology\\nSee also: Nicknames of New York Cit...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            Question  \\\n",
       "0            What do the parameters for HNSW mean?\\n   \n",
       "1  What are HNSW good default parameters when dat...   \n",
       "2  what is the default distance metric used in AU...   \n",
       "3                How did New York City get its name?   \n",
       "\n",
       "                                 ground_truth_answer  \\\n",
       "0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                     M=16, efConstruction=32, ef=32   \n",
       "2  Trick answer:  IP inner product, not yet updat...   \n",
       "3  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                                 Uri  \\\n",
       "0  https://pymilvus.readthedocs.io/en/latest/para...   \n",
       "1  https://pymilvus.readthedocs.io/en/latest/para...   \n",
       "2  https://pymilvus.readthedocs.io/en/latest/tuto...   \n",
       "3        https://en.wikipedia.org/wiki/New_York_City   \n",
       "\n",
       "                                retrieval_chunk_text     H1  \\\n",
       "0  performance, HNSW limits the maximum degree of...  Index   \n",
       "1                                                NaN    NaN   \n",
       "2  The attributes of collection can be extracted ...    NaN   \n",
       "3  Etymology\\nSee also: Nicknames of New York Cit...    NaN   \n",
       "\n",
       "                                                  H2  assistant_answer  Score  \\\n",
       "0  Milvus support to create index to accelerate v...               NaN    NaN   \n",
       "1                                                NaN               NaN    NaN   \n",
       "2                                                NaN               NaN    NaN   \n",
       "3                                                NaN               NaN    NaN   \n",
       "\n",
       "   Reason  \n",
       "0     NaN  \n",
       "1     NaN  \n",
       "2     NaN  \n",
       "3     NaN  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "query = 0              What do the parameters for HNSW mean?\\n\n",
      "1    What are HNSW good default parameters when dat...\n",
      "2    what is the default distance metric used in AU...\n",
      "3                  How did New York City get its name?\n",
      "Name: Question, dtype: object\n",
      "4\n",
      "truth_answers = 0    - M: maximum degree of nodes in a layer of the...\n",
      "1                       M=16, efConstruction=32, ef=32\n",
      "2    Trick answer:  IP inner product, not yet updat...\n",
      "3    In the 1600’s, the Dutch planted a trading pos...\n",
      "Name: ground_truth_answer, dtype: object\n",
      "4\n",
      "truth_uris = 0    https://pymilvus.readthedocs.io/en/latest/para...\n",
      "1    https://pymilvus.readthedocs.io/en/latest/para...\n",
      "2    https://pymilvus.readthedocs.io/en/latest/tuto...\n",
      "3          https://en.wikipedia.org/wiki/New_York_City\n",
      "Name: Uri, dtype: object\n"
     ]
    }
   ],
   "source": [
    "# Read questions and ground truth answers into a pandas dataframe.\n",
    "import pandas as pd\n",
    "\n",
    "# Read ground truth answers from file.\n",
    "eval_df = pd.read_csv(\"../../../christy_coding_scratch/data/milvus_ground_truth.csv\", \n",
    "                      header=0, skip_blank_lines=True)\n",
    "display(eval_df.head())\n",
    "\n",
    "# Get all the questions.\n",
    "query = eval_df.Question\n",
    "print(len(query))\n",
    "print(f\"query = {query}\")\n",
    "\n",
    "# Get all the truth answers.\n",
    "truth_answers = eval_df.ground_truth_answer\n",
    "print(len(truth_answers))\n",
    "print(f\"truth_answers = {truth_answers}\")\n",
    "\n",
    "# Get all the truth uris.\n",
    "truth_uris = eval_df.Uri\n",
    "print(len(truth_uris))\n",
    "print(f\"truth_uris = {truth_uris}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ac09a544",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "question = What do the parameters for HNSW mean?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Choose a sample question and its ground truth answer and uri.\n",
    "QUESTION_NUMBER = 0\n",
    "SAMPLE_QUESTION = query[QUESTION_NUMBER]\n",
    "print(f\"question = {SAMPLE_QUESTION}\")\n",
    "\n",
    "truth_answer = truth_answers[QUESTION_NUMBER]\n",
    "truth_uri = truth_uris[QUESTION_NUMBER]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ea29411",
   "metadata": {},
   "source": [
    "## Execute a vector search\n",
    "\n",
    "Search Milvus using the [PyMilvus API](https://milvus.io/docs/search.md).\n",
    "\n",
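    "Conceptually, `mc.search()` with `limit=top_k` returns the `top_k` stored vectors closest to the query embedding. The plain-Python sketch below shows that idea as an exact, brute-force search, assuming cosine similarity as the metric (the collection's actual metric is set when its index is created, which is not shown in this section). An HNSW index approximates the same result without comparing the query to every stored vector. The helper names are hypothetical and are not part of PyMilvus.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the two vector lengths.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "def brute_force_top_k(query_vector, stored_vectors, top_k):\n",
    "    # Hypothetical helper: score every stored vector, then keep the top_k highest scores.\n",
    "    scored = [(i, cosine_similarity(query_vector, v)) for i, v in enumerate(stored_vectors)]\n",
    "    scored.sort(key=lambda pair: pair[1], reverse=True)\n",
    "    return scored[:top_k]\n",
    "\n",
    "# Toy 3-dimensional vectors; real text embeddings have hundreds of dimensions.\n",
    "stored = [[1.0, 0.0, 0.0], [0.7, 0.7, 0.0], [0.0, 0.0, 1.0]]\n",
    "print(brute_force_top_k([1.0, 0.1, 0.0], stored, top_k=2))  # indices 0 and 1, most similar first\n",
    "```\n",
    "\n",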
    "💡 By their nature, vector searches are \"semantic\" searches.  For example, if you were to search for \"leaky faucet\": \n",
    "> **Traditional Keyword Search** - one or both of the words \"leaky\" and \"faucet\" must appear in a page's text for that page or document to be returned.\n",
    "\n",
    "> **Semantic search** - results containing the words \"drippy\" and \"taps\" would also be returned, because these words mean roughly the same thing even though they are different words."
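    ,
    "\n",
    "\n",
    "A toy illustration of the difference, using invented 3-dimensional vectors in place of real embeddings (a real embedding model would produce these vectors, with hundreds of dimensions): keyword matching finds no shared word between \"leaky faucet\" and \"drippy taps\", while their made-up vectors point in nearly the same direction.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def keyword_match(query, document):\n",
    "    # True if any query word appears verbatim in the document.\n",
    "    doc_words = set(document.lower().split())\n",
    "    return any(word in doc_words for word in query.lower().split())\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))\n",
    "\n",
    "# Invented toy vectors: phrases with similar meanings are placed close together on purpose.\n",
    "toy_embeddings = {\n",
    "    'leaky faucet': [0.9, 0.4, 0.1],\n",
    "    'drippy taps': [0.85, 0.45, 0.15],\n",
    "    'stock prices': [0.1, 0.2, 0.95],\n",
    "}\n",
    "\n",
    "print(keyword_match('leaky faucet', 'drippy taps'))  # False: no shared words\n",
    "print(cosine_similarity(toy_embeddings['leaky faucet'], toy_embeddings['drippy taps']))  # close to 1\n",
    "print(cosine_similarity(toy_embeddings['leaky faucet'], toy_embeddings['stock prices']))  # much lower\n",
    "```"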
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6504d2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_milvus(mc, question, top_k):\n",
    "    # Wrap the mc.search() call in a function\n",
    "\n",
    "    # Embed the question using the same encoder.\n",
    "    query_embeddings = _utils.embed_query(encoder, [question])\n",
    "\n",
    "    # Return top k results with HNSW index.\n",
    "    SEARCH_PARAMS = dict({\n",
    "        # Re-use index param for num_candidate_nearest_neighbors.\n",
    "        \"ef\": INDEX_PARAMS['efConstruction']\n",
    "        })\n",
    "\n",
    "    # Define output fields to return.\n",
    "    OUTPUT_FIELDS = [\"h1\", \"h2\", \"source\", \"chunk\"]\n",
    "\n",
    "    answers = mc.search(\n",
    "        COLLECTION_NAME,\n",
    "        data=query_embeddings, \n",
    "        search_params=SEARCH_PARAMS,\n",
    "        output_fields=OUTPUT_FIELDS, \n",
    "        # Milvus can utilize metadata in boolean expressions to filter search.\n",
    "        # filter=\"\",\n",
    "        limit=top_k,\n",
    "        consistency_level=\"Eventually\"\n",
    "    )\n",
    "    return answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "89642119",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Milvus Client search time for 156 vectors: 0.20205092430114746 seconds\n",
      "type: <class 'list'>, count: 3\n",
      "chunk_answer: performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, you can use efConstruction (when building index) or\n"
     ]
    }
   ],
   "source": [
    "# RETRIEVAL USING MILVUS API.\n",
    "\n",
    "# # Not needed with Milvus Client API.\n",
    "# mc.load()\n",
    "\n",
    "# Define output fields to return.\n",
    "OUTPUT_FIELDS = [\"h1\", \"h2\", \"source\", \"chunk\"]\n",
    "\n",
    "# Run semantic vector search using your query and the vector database.\n",
    "TOP_K = 3\n",
    "start_time = time.time()\n",
    "result = search_milvus(mc, SAMPLE_QUESTION, TOP_K)\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "print(f\"Milvus Client search time for {len(chunk_list)} vectors: {elapsed_time} seconds\")\n",
    "\n",
    "# Inspect search result.\n",
    "print(f\"type: {type(result[0])}, count: {len(result[0])}\")\n",
    "\n",
    "# Milvus Client search time for 156 vectors: 0.1264362335205078 seconds\n",
    "# type: <class 'list'>, count: 3\n",
    "\n",
    "# Extract the retrieval answer.\n",
    "retrieval_answer = result[0][0]['entity']['chunk']\n",
    "print(f\"chunk_answer: {retrieval_answer[:150]}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "410cc3be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LangChain Zilliz search time for 156 vectors: 0.7929909229278564 seconds\n",
      "count retrieval answers: 4\n",
      "RESULTS FOR QUESTION #1:\n",
      "top_k: 1:\n",
      "{'id': 446268198608633780, 'distance': 0.7123057842254639, 'entity': {'chunk': 'performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, you can use efConstruction (when building index) or ef (when searching targets) to specify a search range. building parameters: M: Maximum degree of the node. efConstruction: Take the effect in stage of index construction. # HNSW client.create_index(collection_name, IndexType.HNSW, { \"M\": 16, # int. 4~64 \"efConstruction\": 40 # int. 8~512 } ) search parameters: ef: Take the effect in stage of search scope,', 'source': 'https://pymilvus.readthedocs.io/en/latest/param.html', 'h1': 'Index', 'h2': 'Milvus support to create index to accelerate vecto'}}\n",
      "RESULTS FOR QUESTION #2:\n",
      "top_k: 1:\n",
      "{'id': 446268198608633766, 'distance': 0.7082682847976685, 'entity': {'chunk': 'Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 IVF_SQ8_H IVF_PQ HNSW ANNOY RNSG FLAT¶ If FLAT index is used, the vectors are stored in an array of float/binary data without any compression. during searching vectors, all indexed vectors are decoded sequentially and compared to the query vectors. FLAT index provides 100% query recall rate. Compared to other indexes, it is the most efficient indexing method when the number of queries is small. The inserted and index-inbuilt vectors and index-dropped vectors', 'source': 'https://pymilvus.readthedocs.io/en/latest/param.html', 'h1': 'Index', 'h2': 'Milvus support to create index to accelerate vecto'}}\n",
      "RESULTS FOR QUESTION #3:\n",
      "top_k: 1:\n",
      "{'id': 446268198608633747, 'distance': 0.7685069441795349, 'entity': {'chunk': \"metric_type=) The attributes of collection can be extracted from info. >>> info.collection_name 'demo_film_tutorial' >>> info.dimension 8 >>> info.index_file_size 2048 >>> info.metric_type This tutorial is a basic intro tutorial, building index won’t be covered by this tutorial. If you want to go further into Milvus with indexes, it’s recommended to check our index examples. If you’re already known about indexes from index examples, and you want a full lists of params supported by PyMilvus, you check out\", 'source': 'https://pymilvus.readthedocs.io/en/latest/tutorial.html', 'h1': 'Tutorial', 'h2': 'This is a basic introduction to Milvus by PyMilvus'}}\n",
      "RESULTS FOR QUESTION #4:\n",
      "top_k: 1:\n",
      "{'id': 446268198608633827, 'distance': 0.5131039619445801, 'entity': {'chunk': 'def name(self): return self._name @property def handler(self): return self._handler @deprecated def connect(self, host=None, port=None, uri=None, timeout=2): \"\"\" Deprecated \"\"\" if self.connected() and self._connected: return Status(message=\"You have already connected {} !\".format(self._uri), code=Status.CONNECT_FAILED) if self._stub is None: self._init(host, port, uri, handler=self._handler) if self.ping(timeout): self._status = Status(message=\"Connected\") self._connected = True return self._status return', 'source': 'https://pymilvus.readthedocs.io/en/latest/_modules/milvus/client/stub.html', 'h1': 'Source code for milvus.client.stub # -*- coding: U', 'h2': ''}}\n"
     ]
    }
   ],
   "source": [
    "# Repeat Retrieval step, but loop through list of questions.\n",
    "\n",
    "# # Not needed with Milvus Client API.\n",
    "# mc.load()\n",
    "\n",
    "# Run search_milvus() for every question in the query list.\n",
    "TOP_K = 1\n",
    "start_time = time.time()\n",
    "retrieved_results = [search_milvus(mc, question, TOP_K) for question in query]\n",
    "elapsed_time = time.time() - start_time\n",
    "print(f\"Milvus Client search time for {len(chunks)} vectors: {elapsed_time} seconds\")\n",
    "\n",
    "# Extract list of 0th top_k chunks per question.\n",
    "retrieval_answers = [result[0][0]['entity']['chunk'] for result in retrieved_results]\n",
    "print(f\"count retrieval answers: {len(retrieval_answers)}\")\n",
    "\n",
    "# Print the top_k results for each question.\n",
    "for i, result_list in enumerate(retrieved_results):\n",
    "    print(f\"RESULTS FOR QUESTION #{i+1}:\")\n",
    "    # result_list holds one list of hits per query vector (here, just one).\n",
    "    for hits in result_list:\n",
    "        for k, top_k_result in enumerate(hits):\n",
    "            print(f\"top_k: {k+1}:\")\n",
    "            print(top_k_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assemble and inspect the search results from your docs.\n",
    "\n",
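    "Each entry in `retrieved_results` is the list that `mc.search()` returned for one question. That list holds one list of hits per query vector, and each hit is a dictionary with `id`, `distance`, and `entity` keys, where `entity` contains the requested output fields (`h1`, `h2`, `source`, `chunk`). The cell below separates each hit's `chunk` text from its metadata fields."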
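    ,
    "\n",
    "\n",
    "As a sketch of what the assembly step does, the hypothetical `assemble_context()` below pulls the chunk text and the chosen metadata fields out of one `mc.search()` result shaped like the hits printed above. It is a stand-in, not the actual implementation of `_utils.client_assemble_retrieved_context` (defined in `milvus_utilities`, not shown here), and unlike that helper it returns only the context and metadata lists.\n",
    "\n",
    "```python\n",
    "def assemble_context(search_result, metadata_fields, num_shot_answers=3):\n",
    "    # Hypothetical stand-in for _utils.client_assemble_retrieved_context.\n",
    "    # search_result[0] holds the hits for the first (here, the only) query vector.\n",
    "    hits = search_result[0][:num_shot_answers]\n",
    "    context = [hit['entity']['chunk'] for hit in hits]\n",
    "    context_metadata = [{field: hit['entity'][field] for field in metadata_fields} for hit in hits]\n",
    "    return context, context_metadata\n",
    "\n",
    "# A simplified one-hit result shaped like the MilvusClient output above (id, distance, and text simplified).\n",
    "sample_result = [[\n",
    "    {'id': 1, 'distance': 0.71,\n",
    "     'entity': {'chunk': 'performance, HNSW limits the maximum degree of nodes...',\n",
    "                'source': 'https://pymilvus.readthedocs.io/en/latest/param.html',\n",
    "                'h1': 'Index', 'h2': 'Milvus support to create index to accelerate vecto'}},\n",
    "]]\n",
    "context, context_metadata = assemble_context(sample_result, ['h1', 'h2', 'source'])\n",
    "print(context[0][:50])\n",
    "print(context_metadata[0])\n",
    "```"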
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "QUESTION #1, top_k = 1\n",
      "Context: performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, you can use efConstruction (when building index) or\n",
      "Metadata: {'h1': 'Index', 'h2': 'Milvus support to create index to accelerate vecto', 'source': 'https://pymilvus.readthedocs.io/en/latest/param.html'}\n",
      "\n",
      "QUESTION #2, top_k = 1\n",
      "Context: Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 IVF_SQ8_H IVF_PQ HNSW ANNOY RNSG FLAT¶ If FLAT index is used, the vectors are stored in an array of float\n",
      "Metadata: {'h1': 'Index', 'h2': 'Milvus support to create index to accelerate vecto', 'source': 'https://pymilvus.readthedocs.io/en/latest/param.html'}\n",
      "\n",
      "QUESTION #3, top_k = 1\n",
      "Context: metric_type=) The attributes of collection can be extracted from info. >>> info.collection_name 'demo_film_tutorial' >>> info.dimension 8 >>> info.ind\n",
      "Metadata: {'h1': 'Tutorial', 'h2': 'This is a basic introduction to Milvus by PyMilvus', 'source': 'https://pymilvus.readthedocs.io/en/latest/tutorial.html'}\n",
      "\n",
      "QUESTION #4, top_k = 1\n",
      "Context: def name(self): return self._name @property def handler(self): return self._handler @deprecated def connect(self, host=None, port=None, uri=None, time\n",
      "Metadata: {'h1': 'Source code for milvus.client.stub # -*- coding: U', 'h2': '', 'source': 'https://pymilvus.readthedocs.io/en/latest/_modules/milvus/client/stub.html'}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# For each question, assemble up to `num_shot_answers` retrieved chunks (context) and their metadata.\n",
    "METADATA_FIELDS = [f for f in OUTPUT_FIELDS if f != 'chunk']\n",
    "all_formatted_results = []\n",
    "all_context = []\n",
    "all_context_metadata = []\n",
    "\n",
    "# Iterate over the results for each question\n",
    "for question_results in retrieved_results:\n",
    "    # Assemble the context and context metadata for the current question\n",
    "    formatted_results, context, context_metadata = _utils.client_assemble_retrieved_context(\n",
    "        question_results, metadata_fields=METADATA_FIELDS, num_shot_answers=3)\n",
    "    \n",
    "    # Append the formatted results, context, and context metadata to the corresponding lists\n",
    "    all_formatted_results.append(formatted_results)\n",
    "    all_context.append(context)\n",
    "    all_context_metadata.append(context_metadata)\n",
    "\n",
    "# Print the first 150 characters of each context chunk, followed by its metadata.\n",
    "for i, (context, context_metadata) in enumerate(zip(all_context, all_context_metadata)):\n",
    "    for j in range(len(context)):\n",
    "        print(f\"QUESTION #{i+1}, top_k = {j+1}\")\n",
    "        print(f\"Context: {context[j][:150]}\")\n",
    "        print(f\"Metadata: {context_metadata[j]}\")\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d7e797",
   "metadata": {},
   "source": [
    "## Evaluate retrieval using an LLM as a judge.\n"
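    ,
    "\n",
    "The judge is asked to reply in JSON, so each reply still has to be parsed and checked before its score is trusted. The sketch below assumes a reply is a JSON object with an `llm_zero_shot_similarity_score` field on the 0 to 4 scale used by the prompt in this section; `parse_judge_score()` is a hypothetical helper. It returns `None` for replies that are not valid JSON or whose score is missing or out of range, so they can be excluded before averaging.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def parse_judge_score(reply_text, low=0, high=4):\n",
    "    # Hypothetical helper: return the score as a float, or None if the reply is unusable.\n",
    "    try:\n",
    "        reply = json.loads(reply_text)\n",
    "        score = float(reply['llm_zero_shot_similarity_score'])\n",
    "    except (json.JSONDecodeError, KeyError, TypeError, ValueError):\n",
    "        return None\n",
    "    return score if low <= score <= high else None\n",
    "\n",
    "# Example replies: two valid, one not JSON, one with an out-of-range score.\n",
    "replies = [\n",
    "    json.dumps({'question_number': 0, 'llm_zero_shot_similarity_score': 2}),\n",
    "    json.dumps({'question_number': 1, 'llm_zero_shot_similarity_score': 0}),\n",
    "    'not valid JSON',\n",
    "    json.dumps({'question_number': 3, 'llm_zero_shot_similarity_score': 7}),\n",
    "]\n",
    "scores = [parse_judge_score(r) for r in replies]\n",
    "valid = [s for s in scores if s is not None]\n",
    "print(scores)  # [2.0, 0.0, None, None]\n",
    "print(sum(valid) / len(valid))  # 1.0\n",
    "```"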
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "78126e85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "truth: - M: maximum degree of nodes in a layer of the graph. - efConstruction: number of nearest neighbors \n",
      "\n",
      "retrieval: performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, y\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Question</th>\n",
       "      <th>ground_truth_answer</th>\n",
       "      <th>Uri</th>\n",
       "      <th>retrieval_chunk_text</th>\n",
       "      <th>H1</th>\n",
       "      <th>H2</th>\n",
       "      <th>assistant_answer</th>\n",
       "      <th>Score</th>\n",
       "      <th>Reason</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What do the parameters for HNSW mean?\\n</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>https://pymilvus.readthedocs.io/en/latest/para...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>Index</td>\n",
       "      <td>Milvus support to create index to accelerate v...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are HNSW good default parameters when dat...</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>https://pymilvus.readthedocs.io/en/latest/para...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            Question  \\\n",
       "0            What do the parameters for HNSW mean?\\n   \n",
       "1  What are HNSW good default parameters when dat...   \n",
       "\n",
       "                                 ground_truth_answer  \\\n",
       "0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                     M=16, efConstruction=32, ef=32   \n",
       "\n",
       "                                                 Uri  \\\n",
       "0  https://pymilvus.readthedocs.io/en/latest/para...   \n",
       "1  https://pymilvus.readthedocs.io/en/latest/para...   \n",
       "\n",
       "                                retrieval_chunk_text     H1  \\\n",
       "0  performance, HNSW limits the maximum degree of...  Index   \n",
       "1                                                NaN    NaN   \n",
       "\n",
       "                                                  H2  assistant_answer  Score  \\\n",
       "0  Milvus support to create index to accelerate v...               NaN    NaN   \n",
       "1                                                NaN               NaN    NaN   \n",
       "\n",
       "   Reason  \n",
       "0     NaN  \n",
       "1     NaN  "
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Choose a single truth, retrieval text pair.\n",
    "truth = truth_answer\n",
    "retrieval = retrieval_answer\n",
    "\n",
    "print(f\"truth: {truth[:100]}\\n\")\n",
    "print(f\"retrieval: {retrieval[:100]}\\n\")\n",
    "eval_df.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "75831ad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try using simple LLM-as-judge with zero-shot prompt.\n",
    "\n",
    "import json, pprint\n",
    "import openai, tiktoken\n",
    "from openai import OpenAI\n",
    "\n",
    "# Define the generation llm model to use.\n",
    "LLM_NAME = \"gpt-3.5-turbo-1106\"\n",
    "TEMPERATURE = 0.0\n",
    "\n",
    "# Reasonable values for the penalty coefficients are around 0.1 to 1 if the aim is just to reduce repetition\n",
    "# somewhat. To strongly suppress repetition, set the coefficient to 2 (the maximum).\n",
    "FREQUENCY_PENALTY = 2\n",
    "\n",
    "# See how to save api key in env variable.\n",
    "# https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety\n",
    "openai_client = OpenAI(\n",
    "    # This is the default and can be omitted\n",
    "    # api_key=os.environ.get(\"OPENAI_API_KEY\"),\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a9ba845c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to call OpenAI LLM as judge on zero-shot task.\n",
    "def get_openai_score(llm_name, user_prompt,\n",
    "                     temperature=0.0, random_seed=415, frequency_penalty=2, max_tokens=500):\n",
    "\n",
    "    SYSTEM_PROMPT = f\"\"\"\n",
    "    You are a fair, impartial judge.\n",
    "    \"\"\"\n",
    "        \n",
    "    # Ask the judge model for a JSON-formatted response.\n",
    "    responses = openai_client.chat.completions.create(\n",
    "        response_format={\n",
    "            \"type\": \"json_object\", \n",
    "            # \"schema\": Result.schema_json()\n",
    "        },\n",
    "        messages=[\n",
    "            # {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},  # background tone\n",
    "            # {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"}, # question\n",
    "            # Use assistant messages to provide what was previously said in multi-turn conversations.\n",
    "            # {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"}, \n",
    "\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": user_prompt}\n",
    "        ],\n",
    "        model=llm_name,\n",
    "        temperature=temperature, # the degree of randomness of the model's output\n",
    "        seed=random_seed,  # for reproducibility\n",
    "        frequency_penalty=frequency_penalty, # penalty on tokens that repeat in the model's output\n",
    "        max_tokens=max_tokens # maximum number of tokens the model can output\n",
    "    )\n",
    "\n",
    "    # Record token usage, e.g. to track cost and stay within the model's token limits.\n",
    "    token_dict = {\n",
    "        'prompt_tokens':responses.usage.prompt_tokens,\n",
    "        'completion_tokens':responses.usage.completion_tokens,\n",
    "        'total_tokens':responses.usage.total_tokens,\n",
    "    }\n",
    "\n",
    "    # Parse the answer as a JSON object.\n",
    "    openai_response = responses.choices[0].message.content\n",
    "    # Single JSON object with the fields requested in the prompt.\n",
    "    json_response = json.loads(openai_response)\n",
    "\n",
    "    # Create a DataFrame from a list of dictionaries.\n",
    "    response_df = pd.DataFrame([json_response])\n",
    "    token_df = pd.DataFrame([token_dict])\n",
    "\n",
    "    return response_df, token_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "f9d654f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nFor each question_number: 0, \\ncalculate the similarity between these two texts, using semantic meaning, not word order. \\nText1: - M: maximum degree of nodes in a layer of the graph.\\u2028- efConstruction: number of nearest neighbors to consider when connecting nodes in the graph.\\u2028- ef: number of nearest neighbors to consider when searching for similar vectors.  , Text2: performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, you can use efConstruction (when building index) or ef (when searching targets) to specify a search range. building parameters: M: Maximum degree of the node. efConstruction: Take the effect in stage of index construction. # HNSW client.create_index(collection_name, IndexType.HNSW, { \"M\": 16, # int. 4~64 \"efConstruction\": 40 # int. 8~512 } ) search parameters: ef: Take the effect in stage of search scope,.\\nCalculate llm_zero_shot_similarity_score as a number between 0 and 4, where 4 indicates identical content and 0 indicates completely different content.\\nOutput JSON fields:\\n- question_number\\n- text1\\n- text2 \\n- llm_zero_shot_similarity_score\\n'"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text1 = truth\n",
    "text2 = retrieval\n",
    "question_number = QUESTION_NUMBER\n",
    "\n",
    "ZERO_SHOT_PROMPT = f\"\"\"\n",
    "For each question_number: {question_number}, \n",
    "calculate the similarity between these two texts, using semantic meaning, not word order. \n",
    "Text1: {text1}, Text2: {text2}.\n",
    "Calculate llm_zero_shot_similarity_score as a number between 0 and 4, where 4 indicates identical content and 0 indicates completely different content.\n",
    "Output JSON fields:\n",
    "- question_number\n",
    "- text1\n",
    "- text2 \n",
    "- llm_zero_shot_similarity_score\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "0304a75c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM as judge zero-shot took: 18.57730984687805 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>text1</th>\n",
       "      <th>text2</th>\n",
       "      <th>llm_zero_shot_similarity_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>def name(self): return self._name @property de...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number                                              text1  \\\n",
       "0                0  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                               text2  \\\n",
       "0  def name(self): return self._name @property de...   \n",
       "\n",
       "   llm_zero_shot_similarity_score  \n",
       "0                               0  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt_tokens</th>\n",
       "      <th>completion_tokens</th>\n",
       "      <th>total_tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>322</td>\n",
       "      <td>230</td>\n",
       "      <td>552</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   prompt_tokens  completion_tokens  total_tokens\n",
       "0            322                230           552"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test zero-shot LLM as judge on a single question.\n",
    "# Doc Openai function calling: https://platform.openai.com/docs/guides/function-calling\n",
    "\n",
    "# # CAREFUL!! THIS COSTS MONEY!!\n",
    "# start_time = time.time()\n",
    "# result_df, token_df = get_openai_score(LLM_NAME, ZERO_SHOT_PROMPT, TEMPERATURE)\n",
    "# elapsed_time = time.time() - start_time\n",
    "# print(f\"LLM as judge zero-shot took: {elapsed_time} seconds\")\n",
    "\n",
    "display(result_df.head())  # score = 2\n",
    "token_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "4f2bae81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM as judge zero-shot took: 52.86700487136841 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>text1</th>\n",
       "      <th>text2</th>\n",
       "      <th>llm_zero_shot_similarity_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Trick answer: IP inner product, not yet update...</td>\n",
       "      <td>metric_type=) The attributes of collection can...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>def name(self): return self._name @property de...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number                                              text1  \\\n",
       "0                0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                1                     M=16, efConstruction=32, ef=32   \n",
       "2                2  Trick answer: IP inner product, not yet update...   \n",
       "3                3  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                               text2  \\\n",
       "0  performance, HNSW limits the maximum degree of...   \n",
       "1  Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...   \n",
       "2  metric_type=) The attributes of collection can...   \n",
       "3  def name(self): return self._name @property de...   \n",
       "\n",
       "   llm_zero_shot_similarity_score  \n",
       "0                               2  \n",
       "1                               0  \n",
       "2                               0  \n",
       "3                               0  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "This is an example prompt with placeholders for question_number: 4, text1: - M: maximum degree of nodes in a layer of the graph. - efConstruction: number of nearest neighbors to consider when connecting nodes in the graph. - ef: number of nearest neighbors to consider when searching for similar vectors.  , and text2: performance, HNSW limits the maximum degree of nodes on each layer of the graph to M. In addition, you can use efConstruction (when building index) or ef (when searching targets) to specify a search range. building parameters: M: Maximum degree of the node. efConstruction: Take the effect in stage of index construction. # HNSW client.create_index(collection_name, IndexType.HNSW, { \"M\": 16, # int. 4~64 \"efConstruction\": 40 # int. 8~512 } ) search parameters: ef: Take the effect in stage of search scope,. The curly braces here {} are escaped.\n",
      "This is an example prompt with placeholders for question_number: 4, text1: M=16, efConstruction=32, ef=32, and text2: Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 IVF_SQ8_H IVF_PQ HNSW ANNOY RNSG FLAT¶ If FLAT index is used, the vectors are stored in an array of float/binary data without any compression. during searching vectors, all indexed vectors are decoded sequentially and compared to the query vectors. FLAT index provides 100% query recall rate. Compared to other indexes, it is the most efficient indexing method when the number of queries is small. The inserted and index-inbuilt vectors and index-dropped vectors. The curly braces here {} are escaped.\n",
      "This is an example prompt with placeholders for question_number: 4, text1: Trick answer:  IP inner product, not yet updated in documentation still says L2., and text2: metric_type=) The attributes of collection can be extracted from info. >>> info.collection_name 'demo_film_tutorial' >>> info.dimension 8 >>> info.index_file_size 2048 >>> info.metric_type This tutorial is a basic intro tutorial, building index won’t be covered by this tutorial. If you want to go further into Milvus with indexes, it’s recommended to check our index examples. If you’re already known about indexes from index examples, and you want a full lists of params supported by PyMilvus, you check out. The curly braces here {} are escaped.\n",
      "This is an example prompt with placeholders for question_number: 4, text1: In the 1600’s, the Dutch planted a trading post on the southern tip of the island and named it New Amsterdam, after their capital city in the Netherlands. In 1664, the English seized control of the area from the Dutch and renamed it New York in honor of the Duke of York, who later became King James II of England. The name \"New York\" was chosen to reflect the English influence and to maintain a connection with the Duke of York., and text2: def name(self): return self._name @property def handler(self): return self._handler @deprecated def connect(self, host=None, port=None, uri=None, timeout=2): \"\"\" Deprecated \"\"\" if self.connected() and self._connected: return Status(message=\"You have already connected {} !\".format(self._uri), code=Status.CONNECT_FAILED) if self._stub is None: self._init(host, port, uri, handler=self._handler) if self.ping(timeout): self._status = Status(message=\"Connected\") self._connected = True return self._status return. The curly braces here {} are escaped.\n"
     ]
    }
   ],
   "source": [
    "# Use zero-shot LLM as judge on all question/retrieval pairs.\n",
    "# Also collect the token usage returned for each call.\n",
    "\n",
    "ZERO_SHOT_PROMPT_TEMPLATE = \"\"\"For each question_number: {question_number}, \n",
    "calculate the similarity between these two texts, using semantic meaning, not word order. \n",
    "Text1: {text1}, Text2: {text2}.\n",
    "Calculate llm_zero_shot_similarity_score as a number between 0 and 4, where 4 indicates identical content and 0 indicates completely different content.\n",
    "Output JSON fields:\n",
    "- question_number\n",
    "- text1\n",
    "- text2 \n",
    "- llm_zero_shot_similarity_score\n",
    "\"\"\"\n",
    "\n",
    "# Loop through the truth texts and retrieval texts; evaluate each pair using the LLM.\n",
    "results_list = []\n",
    "tokens_list = []\n",
    "\n",
    "# CAREFUL!! THIS COSTS MONEY!!\n",
    "start_time = time.time()\n",
    "for i, (truth, retrieval) in enumerate(zip(truth_answers, retrieval_answers)):\n",
    "\n",
    "    # Construct the zero-shot prompt.\n",
    "    zero_shot_prompt = ZERO_SHOT_PROMPT_TEMPLATE.format(question_number=i, text1=truth, text2=retrieval)\n",
    "\n",
    "    # Generate the zero-shot LLM-as-judge score and its token counts.\n",
    "    score_df, usage_df = get_openai_score(LLM_NAME, zero_shot_prompt, TEMPERATURE)\n",
    "    results_list.append(score_df)\n",
    "    tokens_list.append(usage_df)\n",
    "elapsed_time = time.time() - start_time\n",
    "print(f\"LLM as judge zero-shot took: {elapsed_time} seconds\")\n",
    "\n",
    "# Concatenate the per-pair DataFrames into single results and token-usage DataFrames.\n",
    "results_df = pd.concat(results_list, ignore_index=True)\n",
    "tokens_df = pd.concat(tokens_list, ignore_index=True)\n",
    "display(results_df.head())\n",
    "tokens_df.head()"
   ]
  },
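  {
   "cell_type": "markdown",
   "id": "c1a4f2e8",
   "metadata": {},
   "source": [
    "Several retrieved chunks are Python code that contains literal `{}`, such as the `connect` snippet retrieved for the last question/retrieval pair. A minimal sketch, using a made-up template and made-up text rather than this notebook's data, of why `str.format` still handles them: braces inside substituted values are copied through unchanged, while literal braces in the template itself must be doubled as `{{` and `}}`.\n",
    "\n",
    "```python\n",
    "# Hypothetical mini-template in the style of ZERO_SHOT_PROMPT_TEMPLATE.\n",
    "TEMPLATE = 'question_number: {question_number}, Text1: {text1}, Text2: {text2}.'\n",
    "\n",
    "# Made-up retrieved chunk containing literal braces, as code snippets often do.\n",
    "retrieved = 'client.create_index(collection_name, {M: 16})'\n",
    "\n",
    "prompt = TEMPLATE.format(question_number=3, text1='truth text', text2=retrieved)\n",
    "print(prompt)  # the {M: 16} inside the substituted value is left as-is\n",
    "\n",
    "# Literal braces in the template itself must be doubled.\n",
    "JSON_HINT = 'Reply as {{\"score\": {score}}}'\n",
    "print(JSON_HINT.format(score=2))  # Reply as {\"score\": 2}\n",
    "```"
   ]
  },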
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "f1d76fb3",
   "metadata": {},
   "outputs": [],
   "source": [
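    "# Now write a few-shot learning prompt for a single question.\n",
    "# Ask for separate correctness and completeness scores, with explicit examples for each score.\n",
    "# Note: `truth` and `retrieval` must hold the pair for QUESTION_NUMBER; if the zero-shot loop above ran last, they hold the last pair.\n",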
    "text1 = truth\n",
    "text2 = retrieval\n",
    "question_number = QUESTION_NUMBER\n",
    "\n",
    "FEW_SHOT_PROMPT = f\"\"\"\n",
    "For each question_number: {question_number}, calculate the similarity between these two texts: \n",
    "Text1: {text1}, Text2: {text2}.\n",
    "\n",
    "  You'll be given a function grading_function, which you'll call for each text pair to submit your reasoning and scores for the Correctness and Completeness of the answer.\n",
    "\n",
    "  Below is your grading rubric: \n",
    "\n",
    "- Correctness: Whether Text2 contains the same key facts as Text1. Below are the details for different scores:\n",
    "\n",
    "  - Score = 0: Text2 is completely incorrect, doesn’t mention anything about Text1, or is completely contrary to Text1.\n",
    "\n",
    "      - For example, when Text2 is an empty string, completely irrelevant content, or \"Sorry, I don’t know the answer.\"\n",
    "\n",
    "  - Score = 1: Text2 is hallucinating on any of the facts from Text1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"Jaccard\"\n",
    "\n",
    "  - Score = 2: Text2 provides some facts from Text1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"L2\"\n",
    "\n",
    "  - Score = 3: Text2 correctly answers the question without missing any major facts.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"L2 or IP\"\n",
    "\n",
    "- Completeness: How complete is the answer? Does it fully answer all aspects of the question and provide a comprehensive explanation and other necessary information? Below are the details for different scores:\n",
    "\n",
    "  - Score 0: If Text2 is completely incorrect, then completeness is also 0.\n",
    "\n",
    "  - Score 1: If the answer is correct but too short to fully answer the question, give completeness a score of 1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M.\"\n",
    "\n",
    "  - Score 2: Text2 is missing descriptive details, or is completely missing one minor fact.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M, efConstruction, and ef.\"\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M and efConstruction.\"\n",
    "\n",
    "  - Score 3: Text2 is correct and covers all the main aspects of the question.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M and ef during construction and ef during search.\"\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are ef during search.  M and ef during construction.\"\n",
    "\n",
    "- Then final rating:\n",
    "\n",
    "    - llm_few_shot_similarity_score: 60% correctness + 40% completeness\n",
    "\n",
    "Output JSON fields:\n",
    "- question_number\n",
    "- llm_few_shot_similarity_score\n",
    "\"\"\""
   ]
  },
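  {
   "cell_type": "markdown",
   "id": "d5b7e3a1",
   "metadata": {},
   "source": [
    "The rubric's final rating is 60% correctness + 40% completeness, with both sub-scores on a 0-3 scale, so the final rating also lies in 0-3. A small worked sketch of that arithmetic, assuming the judge returns integer sub-scores: under that assumption a final rating of 2.4 can only come from correctness 2 and completeness 3, and 1.6 only from correctness 2 and completeness 1. The helper `few_shot_score` is hypothetical and not part of this notebook.\n",
    "\n",
    "```python\n",
    "def few_shot_score(correctness, completeness):\n",
    "    # 60% correctness + 40% completeness, per the rubric.\n",
    "    # Round to one decimal to hide floating-point noise.\n",
    "    return round(0.6 * correctness + 0.4 * completeness, 1)\n",
    "\n",
    "# Map every possible final rating back to the integer sub-score pairs that produce it.\n",
    "pairs = {}\n",
    "for c in range(4):\n",
    "    for k in range(4):\n",
    "        pairs.setdefault(few_shot_score(c, k), []).append((c, k))\n",
    "\n",
    "print(pairs[2.4])  # [(2, 3)]\n",
    "print(pairs[1.6])  # [(2, 1)]\n",
    "print(few_shot_score(3, 3))  # 3.0\n",
    "```"
   ]
  },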
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "bfdc3a81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM as judge few-shot took: 3.3466129302978516 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>llm_few_shot_similarity_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number  llm_few_shot_similarity_score\n",
       "0                0                            2.4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt_tokens</th>\n",
       "      <th>completion_tokens</th>\n",
       "      <th>total_tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>941</td>\n",
       "      <td>28</td>\n",
       "      <td>969</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   prompt_tokens  completion_tokens  total_tokens\n",
       "0            941                 28           969"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
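    "# Test few-shot LLM as judge on a single question.\n",
    "# OpenAI function-calling docs: https://platform.openai.com/docs/guides/function-calling\n",
    "\n",
    "# # CAREFUL!! THIS COSTS MONEY!! Uncomment to rerun; otherwise the display below reuses result2_df and token_df from an earlier run.\n",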
    "# start_time = time.time()\n",
    "# result2_df, token_df = get_openai_score(LLM_NAME, FEW_SHOT_PROMPT, TEMPERATURE)\n",
    "# elapsed_time = time.time() - start_time\n",
    "# print(f\"LLM as judge few-shot took: {elapsed_time} seconds\")\n",
    "\n",
    "display(result2_df.head())  # score = 2.4\n",
    "token_df.head()\n",
    "\n",
    "# Saved output from an earlier run:\n",
    "#    question_number  llm_few_shot_similarity_score\n",
    "# 0                0                            2.4\n",
    "#    prompt_tokens  completion_tokens  total_tokens\n",
    "# 0            941                 28           969"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "f68c932e",
   "metadata": {},
   "outputs": [],
   "source": [
    "FEW_SHOT_PROMPT_TEMPLATE = \"\"\"For each question_number: {question_number}, calculate the similarity between these two texts: \n",
    "Text1: {text1}, Text2: {text2}.\n",
    "\n",
    "  You'll be given a function grading_function, which you'll call for each text pair to submit your reasoning and scores for the Correctness and Completeness of the answer.\n",
    "\n",
    "  Below is your grading rubric: \n",
    "\n",
    "- Correctness: Whether Text2 contains the same key facts as Text1. Below are the details for different scores:\n",
    "\n",
    "  - Score = 0: Text2 is completely incorrect, doesn’t mention anything about Text1, or is completely contrary to Text1.\n",
    "\n",
    "      - For example, when Text2 is an empty string, completely irrelevant content, or \"Sorry, I don’t know the answer.\"\n",
    "\n",
    "  - Score = 1: Text2 is hallucinating on any of the facts from Text1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"Jaccard\"\n",
    "\n",
    "  - Score = 2: Text2 provides some facts from Text1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"L2\"\n",
    "\n",
    "  - Score = 3: Text2 correctly answers the question without missing any major facts.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"L2 according to documentation, but in the code it is IP inner product.\"\n",
    "\n",
    "          - Text2: \"L2 or IP\"\n",
    "\n",
    "- Completeness: How complete is the answer? Does it fully answer all aspects of the question and provide a comprehensive explanation and other necessary information? Below are the details for different scores:\n",
    "\n",
    "  - Score 0: If Text2 is completely incorrect, then completeness is also 0.\n",
    "\n",
    "  - Score 1: If the answer is correct but too short to fully answer the question, give completeness a score of 1.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M.\"\n",
    "\n",
    "  - Score 2: Text2 is missing descriptive details, or is completely missing one minor fact.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M, efConstruction, and ef.\"\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M and efConstruction.\"\n",
    "\n",
    "  - Score 3: Text2 is correct and covers all the main aspects of the question.\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are M and ef during construction and ef during search.\"\n",
    "\n",
    "      - Example:\n",
    "\n",
    "          - Text1: \"The parameters for HNSW are M and efConstruction during construction. During search param is ef.\"\n",
    "\n",
    "          - Text2: \"The parameters for HNSW are ef during search.  M and ef during construction.\"\n",
    "\n",
    "- Then final rating:\n",
    "\n",
    "    - llm_few_shot_similarity_score: 60% correctness + 40% completeness\n",
    "\n",
    "Output JSON fields:\n",
    "- question_number\n",
    "- llm_few_shot_similarity_score\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "35df0c32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM as judge few-shot took: 16.285736083984375 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>llm_few_shot_similarity_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number  llm_few_shot_similarity_score\n",
       "0                0                            2.4\n",
       "1                1                            2.4\n",
       "2                2                            1.6\n",
       "3                3                            0.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt_tokens</th>\n",
       "      <th>completion_tokens</th>\n",
       "      <th>total_tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>940</td>\n",
       "      <td>28</td>\n",
       "      <td>968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>895</td>\n",
       "      <td>28</td>\n",
       "      <td>923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>890</td>\n",
       "      <td>28</td>\n",
       "      <td>918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>972</td>\n",
       "      <td>28</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   prompt_tokens  completion_tokens  total_tokens\n",
       "0            940                 28           968\n",
       "1            895                 28           923\n",
       "2            890                 28           918\n",
       "3            972                 28          1000"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use few-shot LLM as judge on all question/retrieval pairs.\n",
    "# Also collect the token usage returned for each call.\n",
    "\n",
    "# Loop through the truth texts and retrieval texts; evaluate each pair using the LLM.\n",
    "results2_list = []\n",
    "tokens_list = []\n",
    "\n",
    "# CAREFUL!! THIS COSTS MONEY!!\n",
    "start_time = time.time()\n",
    "for i, (truth, retrieval) in enumerate(zip(truth_answers, retrieval_answers)):\n",
    "\n",
    "    # Construct the few-shot prompt.\n",
    "    few_shot_prompt = FEW_SHOT_PROMPT_TEMPLATE.format(question_number=i, text1=truth, text2=retrieval)\n",
    "\n",
    "    # Generate the few-shot LLM-as-judge score and its token counts.\n",
    "    score_df, usage_df = get_openai_score(LLM_NAME, few_shot_prompt, TEMPERATURE)\n",
    "    results2_list.append(score_df)\n",
    "    tokens_list.append(usage_df)\n",
    "elapsed_time = time.time() - start_time\n",
    "print(f\"LLM as judge few-shot took: {elapsed_time} seconds\")\n",
    "\n",
    "# Concatenate the per-pair DataFrames into single results and token-usage DataFrames.\n",
    "results2_df = pd.concat(results2_list, ignore_index=True)\n",
    "tokens_df = pd.concat(tokens_list, ignore_index=True)\n",
    "display(results2_df.head())\n",
    "tokens_df.head()"
   ]
  },
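  {
   "cell_type": "markdown",
   "id": "e9c2d6f4",
   "metadata": {},
   "source": [
    "Because every judge call is billed, it can help to total the token counts before rerunning the loops. A minimal sketch, using the few-shot token counts recorded in the output above and per-1,000-token prices that you supply yourself; the prices shown are placeholders, not actual OpenAI rates, and `estimate_cost` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def estimate_cost(tokens_df, usd_per_1k_prompt, usd_per_1k_completion):\n",
    "    # Sum tokens over all judge calls, then price prompt and completion tokens separately.\n",
    "    prompt = tokens_df['prompt_tokens'].sum()\n",
    "    completion = tokens_df['completion_tokens'].sum()\n",
    "    return float(prompt * usd_per_1k_prompt + completion * usd_per_1k_completion) / 1000\n",
    "\n",
    "# Token counts from the few-shot run above.\n",
    "tokens_df = pd.DataFrame({\n",
    "    'prompt_tokens': [940, 895, 890, 972],\n",
    "    'completion_tokens': [28, 28, 28, 28],\n",
    "})\n",
    "print(int(tokens_df.to_numpy().sum()))  # 3809\n",
    "\n",
    "# Placeholder prices; look up the current rates for your model.\n",
    "print(estimate_cost(tokens_df, usd_per_1k_prompt=0.001, usd_per_1k_completion=0.002))\n",
    "```"
   ]
  },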
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "54233ab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convenience function to get the top-1 source URI from each retrieved result.\n",
    "def get_references(result):\n",
    "    sources = []\n",
    "\n",
    "    for r in result:\n",
    "        # r[0] = hits for the first (only) query vector; [0] = top hit.\n",
    "        sources.append(r[0][0]['entity']['source'])\n",
    "\n",
    "    return sources\n",
    "\n",
    "# Define a binary score: does the top retrieval source match the ground truth source?\n",
    "def get_source_binary_score(truth_uris, retrieved_uris):\n",
    "    \"\"\"\n",
    "    For each question, return 1 if the top (0th) retrieved URI matches the truth URI, else 0.\n",
    "    \"\"\"\n",
    "    retrieval_scores = []\n",
    "    for tr, rr in zip(truth_uris, retrieved_uris):\n",
    "        # Compare only the last path segment of each URI,\n",
    "        # e.g. 'New_York_City' in https://en.wikipedia.org/wiki/New_York_City\n",
    "        retrieval_score = 1 if tr.split(\"/\")[-1] == rr.split(\"/\")[-1] else 0\n",
    "        retrieval_scores.append(retrieval_score)\n",
    "\n",
    "    return retrieval_scores"
   ]
  },
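  {
   "cell_type": "markdown",
   "id": "f3a8b1c7",
   "metadata": {},
   "source": [
    "A minimal sketch of what these two helpers expect, run on made-up data. The nested shape (per question, one list of hits per query vector; each hit a dict with an `entity` field) is an assumption inferred from the indexing `r[0][0]['entity']['source']`, and the URLs are invented. It also shows two consequences of comparing only the last path segment: pages with the same file name in different directories count as a match, and any two URIs ending in `/` match because both last segments are empty.\n",
    "\n",
    "```python\n",
    "# Toy stand-in for retrieved_results: one search response per question.\n",
    "retrieved_results = [\n",
    "    [[{'id': 1, 'distance': 0.91, 'entity': {'source': 'https://example.com/docs/index.html'}}]],\n",
    "    [[{'id': 7, 'distance': 0.55, 'entity': {'source': 'https://example.com/docs/about.html'}}]],\n",
    "]\n",
    "truth_uris = ['https://example.com/v2/docs/index.html', 'https://example.com/docs/faq.html']\n",
    "\n",
    "def get_references(result):\n",
    "    # Top hit of the first query vector for each question.\n",
    "    return [r[0][0]['entity']['source'] for r in result]\n",
    "\n",
    "def get_source_binary_score(truth_uris, retrieved_uris):\n",
    "    # 1 if the last path segments match, else 0.\n",
    "    return [1 if t.split('/')[-1] == r.split('/')[-1] else 0\n",
    "            for t, r in zip(truth_uris, retrieved_uris)]\n",
    "\n",
    "retrieved_uris = get_references(retrieved_results)\n",
    "print(get_source_binary_score(truth_uris, retrieved_uris))  # [1, 0]\n",
    "\n",
    "# Trailing slashes leave an empty last segment, so these two unrelated pages match.\n",
    "print(get_source_binary_score(['https://a.com/x/'], ['https://b.com/y/']))  # [1]\n",
    "```"
   ]
  },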
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "ca80ba7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uris: 4, sources: 4\n",
      "Binary score for retrieval = [1, 1, 1, 0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>text1</th>\n",
       "      <th>text2</th>\n",
       "      <th>llm_zero_shot_similarity_score</th>\n",
       "      <th>binary_source_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Trick answer: IP inner product, not yet update...</td>\n",
       "      <td>metric_type=) The attributes of collection can...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>def name(self): return self._name @property de...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number                                              text1  \\\n",
       "0                0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                1                     M=16, efConstruction=32, ef=32   \n",
       "2                2  Trick answer: IP inner product, not yet update...   \n",
       "3                3  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                               text2  \\\n",
       "0  performance, HNSW limits the maximum degree of...   \n",
       "1  Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...   \n",
       "2  metric_type=) The attributes of collection can...   \n",
       "3  def name(self): return self._name @property de...   \n",
       "\n",
       "   llm_zero_shot_similarity_score  binary_source_score  \n",
       "0                               2                    1  \n",
       "1                               0                    1  \n",
       "2                               0                    1  \n",
       "3                               0                    0  "
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score retrieval with a rough binary check of whether the retrieved source matches the ground truth source.\n",
    "\n",
    "# Get the top-1 source URI from each retrieved result.\n",
    "retrieved_uris = get_references(retrieved_results)\n",
    "print(f\"uris: {len(truth_uris)}, sources: {len(retrieved_uris)}\")\n",
    "\n",
    "# 1 if the 0th (top) retrieval source matches the ground truth source, else 0.\n",
    "binary_scores = get_source_binary_score(truth_uris, retrieved_uris)\n",
    "print(f\"Binary score for retrieval = {binary_scores}\")\n",
    "\n",
    "# Append the binary sources score to the eval results dataframe.\n",
    "results_df['binary_source_score'] = binary_scores\n",
    "results_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "d4b51ded",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>text1</th>\n",
       "      <th>text2</th>\n",
       "      <th>llm_zero_shot_similarity_score</th>\n",
       "      <th>binary_source_score</th>\n",
       "      <th>llm_few_shot_similarity_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Trick answer: IP inner product, not yet update...</td>\n",
       "      <td>metric_type=) The attributes of collection can...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>def name(self): return self._name @property de...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number                                              text1  \\\n",
       "0                0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                1                     M=16, efConstruction=32, ef=32   \n",
       "2                2  Trick answer: IP inner product, not yet update...   \n",
       "3                3  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                               text2  \\\n",
       "0  performance, HNSW limits the maximum degree of...   \n",
       "1  Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...   \n",
       "2  metric_type=) The attributes of collection can...   \n",
       "3  def name(self): return self._name @property de...   \n",
       "\n",
       "   llm_zero_shot_similarity_score  binary_source_score  \\\n",
       "0                               2                    1   \n",
       "1                               0                    1   \n",
       "2                               0                    1   \n",
       "3                               0                    0   \n",
       "\n",
       "   llm_few_shot_similarity_score  \n",
       "0                            2.4  \n",
       "1                            2.4  \n",
       "2                            1.6  \n",
       "3                            0.0  "
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Append the few-shot scores to the eval results dataframe.\n",
    "results_df['llm_few_shot_similarity_score'] = results2_df['llm_few_shot_similarity_score']\n",
    "results_df.head()\n"
   ]
  },
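  {
   "cell_type": "markdown",
   "id": "a6d9e2b5",
   "metadata": {},
   "source": [
    "The assignment above works because pandas aligns columns on the row index, and both `results_df` and `results2_df` were built with `pd.concat(..., ignore_index=True)`, so their row labels line up. A more explicit alternative, sketched here on the scores recorded above, joins on `question_number`; `validate='one_to_one'` makes pandas raise an error if a question number is duplicated. The reordered copy is made up to show the difference.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Scores recorded in this notebook's outputs.\n",
    "results_df = pd.DataFrame({\n",
    "    'question_number': [0, 1, 2, 3],\n",
    "    'llm_zero_shot_similarity_score': [2, 0, 0, 0],\n",
    "})\n",
    "results2_df = pd.DataFrame({\n",
    "    'question_number': [0, 1, 2, 3],\n",
    "    'llm_few_shot_similarity_score': [2.4, 2.4, 1.6, 0.0],\n",
    "})\n",
    "\n",
    "# Same scores in a different row order with fresh row labels;\n",
    "# plain column assignment would now pair question 0 with 0.0.\n",
    "shuffled = results2_df.iloc[::-1].reset_index(drop=True)\n",
    "\n",
    "# Joining on the question id keeps each score with its own question.\n",
    "merged = results_df.merge(shuffled, on='question_number', how='left', validate='one_to_one')\n",
    "print(merged['llm_few_shot_similarity_score'].tolist())  # [2.4, 2.4, 1.6, 0.0]\n",
    "```"
   ]
  },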
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "cc46792b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question_number</th>\n",
       "      <th>text1</th>\n",
       "      <th>text2</th>\n",
       "      <th>llm_zero_shot_similarity_score</th>\n",
       "      <th>binary_source_score</th>\n",
       "      <th>llm_few_shot_similarity_score</th>\n",
       "      <th>final_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>- M: maximum degree of nodes in a layer of the...</td>\n",
       "      <td>performance, HNSW limits the maximum degree of...</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2.4</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2.4</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Trick answer: IP inner product, not yet update...</td>\n",
       "      <td>metric_type=) The attributes of collection can...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1.6</td>\n",
       "      <td>1.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>In the 1600’s, the Dutch planted a trading pos...</td>\n",
       "      <td>def name(self): return self._name @property de...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   question_number                                              text1  \\\n",
       "0                0  - M: maximum degree of nodes in a layer of the...   \n",
       "1                1                     M=16, efConstruction=32, ef=32   \n",
       "2                2  Trick answer: IP inner product, not yet update...   \n",
       "3                3  In the 1600’s, the Dutch planted a trading pos...   \n",
       "\n",
       "                                               text2  \\\n",
       "0  performance, HNSW limits the maximum degree of...   \n",
       "1  Metrics. Vector Index¶ FLAT IVF_FLAT IVF_SQ8 I...   \n",
       "2  metric_type=) The attributes of collection can...   \n",
       "3  def name(self): return self._name @property de...   \n",
       "\n",
       "   llm_zero_shot_similarity_score  binary_source_score  \\\n",
       "0                               2                    1   \n",
       "1                               0                    1   \n",
       "2                               0                    1   \n",
       "3                               0                    0   \n",
       "\n",
       "   llm_few_shot_similarity_score  final_score  \n",
       "0                            2.4          1.7  \n",
       "1                            2.4          1.7  \n",
       "2                            1.6          1.3  \n",
       "3                            0.0          0.0  "
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ignore the zero-shot score (the column is kept in the dataframe); use the few-shot score instead.\n",
    "\n",
    "# Calculate the final eval score as the equal-weight average (mean) of the binary source score and the few-shot score.\n",
    "results_df['final_score'] = (results_df['binary_source_score'] + results_df['llm_few_shot_similarity_score']) / 2\n",
    "results_df.head()"
   ]
  },
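  {
   "cell_type": "markdown",
   "id": "5e2b7c41",
   "metadata": {},
   "source": [
    "The cell above takes a plain mean, so the binary source score and the few-shot score each get a weight of 0.5.  Below is a minimal sketch of an explicitly weighted version, for experimenting with other weightings.  The helper `weighted_final_score` and its `w_source` parameter are hypothetical names introduced here, not part of the original evaluation, and the toy data simply copies the four rows shown above.\n",
    "\n",
    "Note that the two inputs are on different scales (the binary score is 0 or 1, while the few-shot scores above reach 2.4), so the weights also decide which scale dominates the final score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d4a0f63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def weighted_final_score(df, w_source=0.5):\n",
    "    # Hypothetical helper: weighted average of the binary source score\n",
    "    # and the few-shot LLM score; w_source=0.5 reproduces the plain mean.\n",
    "    w_llm = 1.0 - w_source\n",
    "    return w_source * df['binary_source_score'] + w_llm * df['llm_few_shot_similarity_score']\n",
    "\n",
    "# Toy data copied from the four rows in the table above.\n",
    "demo_df = pd.DataFrame({\n",
    "    'binary_source_score': [1, 1, 1, 0],\n",
    "    'llm_few_shot_similarity_score': [2.4, 2.4, 1.6, 0.0],\n",
    "})\n",
    "demo_df['final_score'] = weighted_final_score(demo_df)\n",
    "print(demo_df['final_score'].round(2).tolist())"
   ]
  },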
  {
   "cell_type": "markdown",
   "id": "bd6060ce",
   "metadata": {},
   "source": [
    "## Use an LLM to generate a chat response to the user's question using the retrieved context.\n",
    "\n",
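    "Below, we'll use a very small open-source model available on Hugging Face.  Strictly speaking, it is an extractive question-answering model rather than a generative LLM: instead of writing new text, it selects the span of the context whose start and end tokens score highest.  Many demos use OpenAI as the LLM instead."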
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff76685",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USING A TINY OSS LLM: ASK THE SAME QUESTION WITH RETRIEVED CONTEXT.\n",
    "\n",
    "# Define the context: a short slice of the first retrieved context.\n",
    "context_slice = context[0][111:257]\n",
    "# Alternative short prompt for the tiny model (used only if swapped in below).\n",
    "short_prompt = f\"\"\"Explain more using the Context or say \"I don't know\".\n",
    "Context: {context_slice}\n",
    "\"\"\"\n",
    "\n",
    "# Set the encoding parameters\n",
    "encoding_parameters = {\n",
    "    \"return_tensors\": \"pt\",  # Return PyTorch tensors\n",
    "    \"max_length\": MAX_SEQ_LENGTH,  # Maximum length for the encoded tokens\n",
    "    \"truncation\": True,  # Enable truncation to avoid sequences longer than max_length\n",
    "}\n",
    "\n",
    "# Encode the inputs for question-answering\n",
    "inputs = tokenizer.encode_plus(\n",
    "    SAMPLE_QUESTION,  # The question to be asked\n",
    "    context_slice,  # The context in which the question is asked\n",
    "    # To try the short prompt, replace context_slice above with:\n",
    "    # short_prompt,\n",
    "    **encoding_parameters  # The encoding parameters\n",
    ")\n",
    "\n",
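    "# Run the model; it scores each token as a possible answer start or end.\n",
    "output = model(**inputs)\n",
    "# Take the most likely start and end positions (end is exclusive for slicing).\n",
    "start_index = torch.argmax(output.start_logits)\n",
    "end_index = torch.argmax(output.end_logits) + 1\n",
    "# Decode the tokens between start and end back into a string.\n",
    "answer = tokenizer.convert_tokens_to_string(\n",
    "    tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0][start_index:end_index]))\n",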
    "\n",
    "# Print the extracted answer\n",
    "print(\"Generated Answer:\", answer)\n",
    "\n",
    "# Result: the answer improves with the retrieved context, but is still incomplete."
   ]
  },
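  {
   "cell_type": "markdown",
   "id": "c3e81f07",
   "metadata": {},
   "source": [
    "The cell above picks the start and end positions with two independent `argmax` calls, which can return an end position before the start and therefore an empty answer.  A common refinement, similar in spirit to the post-processing in Hugging Face's question-answering pipeline, scores every (start, end) pair as `start_logit + end_logit` and keeps the best pair with start <= end and a bounded length.\n",
    "\n",
    "Below is a minimal NumPy sketch of that idea.  `best_span` and `max_answer_len` are hypothetical names, and the logits are toy values rather than real model outputs.  To try it on the model above, one could pass `output.start_logits[0].detach().numpy()` and `output.end_logits[0].detach().numpy()`; a fuller version would also exclude the question and special tokens from the candidate spans."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b19d5aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def best_span(start_logits, end_logits, max_answer_len=30):\n",
    "    # Score every (start, end) pair as start_logit + end_logit.\n",
    "    scores = start_logits[:, None] + end_logits[None, :]\n",
    "    n = len(start_logits)\n",
    "    # Keep only pairs with start <= end and span length <= max_answer_len tokens.\n",
    "    i, j = np.indices((n, n))\n",
    "    valid = (j >= i) & (j - i < max_answer_len)\n",
    "    scores = np.where(valid, scores, -np.inf)\n",
    "    # Pick the best valid pair; return an exclusive end index for slicing.\n",
    "    start, end = np.unravel_index(np.argmax(scores), scores.shape)\n",
    "    return int(start), int(end) + 1\n",
    "\n",
    "# Toy logits: independent argmax picks start=3 but end=1, giving the empty slice [3:2].\n",
    "start_logits = np.array([0.1, 0.2, 2.0, 3.0, 0.0])\n",
    "end_logits = np.array([0.0, 4.0, 0.5, 2.5, 0.3])\n",
    "print(best_span(start_logits, end_logits))  # (3, 4)"
   ]
  },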
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e81e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up: drop the Milvus collection created earlier.\n",
    "utility.drop_collection(COLLECTION_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c777937e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Props to Sebastian Raschka for this handy watermark.\n",
    "# !pip install watermark\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -a 'Christy Bergman' -v -p torch,transformers,sentence_transformers,pymilvus,langchain,openai --conda"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
